//
// Created by song on 16-11-18.
//

#ifndef C0COMPILER_ASTTOIRINTERPRETER_H
#define C0COMPILER_ASTTOIRINTERPRETER_H


#include "../front/ASTNode.h"
#include "../front/SymbolTable.h"
#include "../front/ASTVisitor.h"
#include "IR.h"


class ASTtoIRInterpreter: public ASTVisitor {
private:
    ASTNode& ast;
    SymbolTable& symTable;
    IRList* irList;

public:
    ASTtoIRInterpreter(ASTNode& node, SymbolTable& table): ast(node), symTable(table)
    {
        firstFun = false;
        irList = new IRList();
        this->postVisitSet = {
            SourceCode,
            FunctionDeclar,
            IfStmt,
            WhileStmt,
            SwitchStmt,
            CaseSegment,
            DefaultSegment,
            AssignStmt,
            ReturnStmt,
            RelationExpr,
            AddSubExpr,
            MultiDivExpr,
            ArraySubscript,
            FunctionInvoke,
            ConditionExpr
        };
        this->defaultAction = true;
    }

    IRList* eval() {
        startTraverse(ast);
        return irList;
    }

protected:
    bool firstFun;
    map<ASTNode*, map<string, TmpVar*>> properties;
    TmpVar* setProperty(ASTNode& node, string key, TmpVar* var){
        if(properties.count(&node)==0){
            properties[&node] = map<string, TmpVar*>();
        }
        map<string, TmpVar *> &m = properties[&node];
        m[key] = var;
        return var;
    }

    TmpVar* getProperty(ASTNode& node, string key){
        if(properties.count(&node)>0){
            map<string, TmpVar *> &m = properties[&node];
            if(m.count(key)>0){
                return m[key];
            }else{
                Error::nextErrorDetail << "[ast2IR] getProperty failed: property "<< key <<"not found on node " << node.toString();
                Error::internal(Error::Should_Not_Happen);
            }
        }else{
            Error::internal(Error::Should_Not_Happen);
        }
    }

    TmpVar* checkConvertLabel;
    TmpVar* errorLabel;
    void int2CharConvertCheck(TmpVar* t){
        appendIR(NULL, PUSHPARAM, t, NULL, NULL);
        appendIR(NULL, CALL, checkConvertLabel, NULL, NULL);
    }

    void appendIR(TmpVar* label, IRType op, TmpVar* left, TmpVar* right, TmpVar* result){
        irList->append(IR::B(label, op, left, right, result));
    }

    virtual bool constDeclar(ASTNode &node) override {
        SymbolTable::VarRecord* v = symTable.getVar(node.id);
        TmpVar* value, *name;
        if(v->type==SymbolTable::CHAR){
            value = TmpVar::num(node.ch);
        }else{
            value = TmpVar::num(node.value);
        }
        name = TmpVar::var(v->name);
        v->ref = name;
        appendIR(NULL, ASSIGN, value, NULL, name);
        return false;
    }

    virtual bool varDeclar(ASTNode &node) override {
        SymbolTable::VarRecord* v = symTable.getVar(node.id);
        TmpVar* value, *name;
        value = TmpVar::num(0);
        name = TmpVar::var(v->name);
        v->ref = name;
        appendIR(NULL, ASSIGN, value, NULL, name);
        return false;
    }

    virtual bool arrayDeclar(ASTNode &node) override {
        SymbolTable::VarRecord* v = symTable.getVar(node.id);
        TmpVar* value, *name;
        value = TmpVar::num(v->size);
        name = TmpVar::var(v->name);
        v->ref = name;
        if(v->isGlobal){
            appendIR(NULL, ARRAY, value, NULL, name);
        }else{
            appendIR(NULL, PUSHARR, value, NULL, name);
        }
        return false;
    }

    virtual bool funDeclar(ASTNode &node, bool postVisit) override {
        SymbolTable::FunRecord* f = symTable.getFun(node.id);
        if(!firstFun){
            firstFun = true;
            TmpVar* programBegin = getProperty(*node.parent, "programBegin");
            appendIR(TmpVar::label("functions"), JMP, NULL, NULL, programBegin);
        }
        if(!postVisit){
            TmpVar* name = TmpVar::label(f->name);
            f->ref = name;
            appendIR(name, FUNBEGIN, NULL, NULL, NULL);
            if(f->returnType==SymbolTable::VOID && f->name=="main"){
                setProperty(*node.parent, "mainFun", name);
            }
        }else{
            int len = (int) node.children.size();
            if(f->returnType==SymbolTable::VOID){
                if(node.children[len-1]->astType!=ReturnStmt){
                    appendIR(NULL, RETURN, NULL, NULL, NULL);
                }
            }else{ // function not return value. runtime error.
                if(node.children[len-1]->astType!=ReturnStmt){
                    appendIR(NULL, ERROR, TmpVar::num(0), NULL, NULL);
                }
            }
            appendIR(NULL, FUNEND, NULL, NULL, NULL);
        }
        return true;
    }

    virtual bool param(ASTNode &node) override {
        SymbolTable::VarRecord* v = symTable.getVar(node.id);
        v->ref = TmpVar::var(v->name);
        appendIR(NULL, POPPARAM, NULL, NULL, v->ref);
        return false;
    }

    virtual bool relationExpr(ASTNode &node, bool postVisit) override {
        if(postVisit){
            TmpVar* left = getProperty(* (node.children[0]), "result");
            TmpVar* right = getProperty(* (node.children[1]), "result");
            IRType t;
            switch(node.contentType) {
                case Equal:
                    t = SE;
                    break;
                case NotEqual:
                    t = SNE;
                    break;
                case LessThan:
                    t = SL;
                    break;
                case LessEqual:
                    t = SLE;
                    break;
                case GreatThan:
                    t = SG;
                    break;
                case GreatEqual:
                    t = SGE;
                    break;
                default:
                    Error::internal(Error::Should_Not_Happen);
            }
            TmpVar* result = TmpVar::var();
            setProperty(node, "result", result);
            appendIR(NULL, t, left, right, result);
        }
        return true;
    }

    virtual bool addSubExpr(ASTNode &node, bool postVisit) override {
        if(postVisit){
            TmpVar* result = TmpVar::var();
            TmpVar* left = getProperty(* (node.children[0]), "result");
            TmpVar* right = getProperty(* (node.children[1]), "result");
            IRType t;
            switch(node.contentType) {
                case Add:
                    t = ADD;
                    break;
                case Minus:
                    t = SUB;
                    break;
                default:
                    Error::internal(Error::Should_Not_Happen);
            }
            setProperty(node, "result", result);
            appendIR(NULL, t, left, right, result);
            if(node.needCheck){
                int2CharConvertCheck(result);
            }
            TmpVar* end = TmpVar::label("end");
            if(t==ADD){
//                appendIR(NULL, J, left, right, result);
            }
        }
        return true;
    }

    virtual bool multiDivExpr(ASTNode &node, bool postVisit) override {
        if(postVisit){
            TmpVar* result = TmpVar::var();
            TmpVar* left = getProperty(* (node.children[0]), "result");
            TmpVar* right = getProperty(* (node.children[1]), "result");
            IRType t;
            switch(node.contentType) {
                case Mul:
                    t = MUL;
                    break;
                case Div:
                    t = DIV;
                    break;
                default:
                    Error::internal(Error::Should_Not_Happen);
            }
            setProperty(node, "result", result);
            appendIR(NULL, t, left, right, result);
            if(node.needCheck){
                int2CharConvertCheck(result);
            }
        }
        return true;
    }

    virtual bool functionInvoke(ASTNode &node, bool postVisit) override {
        if(postVisit){
            if(node.id==0){ // printf
                for (int i=0; i< node.children.size(); i++) {
                    ASTNode & child = * node.children[i];
                    SymbolTable::VarRecord* v;
                    SymbolTable::FunRecord* childFun;
                    IRType op;
                    switch (child.astType){
                        case StringConst:
                            op = WRITESTR;
                            break;
                        case CharConst:
                            op = WRITECHAR;
                            break;
                        case Identifier:
                            v = symTable.getVar(child.id);
                            if(v->type==SymbolTable::CHAR){
                                op = WRITECHAR;
                            }else{
                                op = WRITE;
                            }
                            break;
                        case FunctionInvoke:
                            childFun = symTable.getFun(child.id);
                            if(childFun->returnType==SymbolTable::CHAR){
                                op = WRITECHAR;
                            }else{
                                op = WRITE;
                            }
                            break;
                        default:
                            op = WRITE;
                    }
                    appendIR(NULL, op, getProperty(child, "result"), NULL, NULL);
                }
                appendIR(NULL, WRITELN, NULL, NULL, NULL);
            }else if(node.id==1){ // scanf
                vector<ASTNode *>::const_iterator iter;
                for (iter = node.children.begin();
                     iter != node.children.end(); iter++) {
                    ASTNode& child = **iter;
                    SymbolTable::VarRecord* v = symTable.getVar(child.id);
                    IRType op;
                    if(v->type==SymbolTable::CHAR){
                        op = READCHAR;
                    }else{
                        op = READ;
                    }
                    appendIR(NULL, op, NULL, NULL, v->ref);
                }
            }else{
                vector<ASTNode *>::const_iterator iter;
                for (iter = node.children.begin();
                     iter != node.children.end(); iter++) {
                    TmpVar* exprValue = getProperty(**iter, "result");
                    appendIR(NULL, PUSHPARAM, exprValue, NULL, NULL);
                }
                SymbolTable::FunRecord* f = symTable.getFun(node.id);
                if(f->returnType==SymbolTable::VOID){
                    appendIR(NULL, CALL, f->ref, NULL, NULL);
                }else{
                    TmpVar* result = TmpVar::var();
                    setProperty(node, "result", result);
                    appendIR(NULL, CALL, f->ref, NULL, result);
                    if(node.needCheck){
                        int2CharConvertCheck(result);
                    }
                }
            }
        }
        return true;
    }

    virtual bool condition(ASTNode &node, bool postVisit) override {
        if(postVisit){
            TmpVar* result = getProperty(*(node.children[0]), "result");
            setProperty(node, "result", result);
//            if(node.parent->astType==IfStmt){
//                TmpVar* elsePart = getProperty(* node.parent, "else");
//                TmpVar* ifFalse = TmpVar::num(0);
//                appendIR(NULL, JE, result, ifFalse, elsePart);
//            }else if(node.parent->astType==SwitchStmt){ // switch
//                setProperty(node, "result", result);
//            }else { // while
//                TmpVar* endOfWhile = getProperty(* node.parent, "endOfWhile");
//                TmpVar* ifFalse = TmpVar::num(0);
//                appendIR(NULL, JE, result, ifFalse, endOfWhile);
//            }
        }
        return true;
    }

    virtual bool ifStmt(ASTNode &node, bool postVisit) override {
        if(!postVisit && node.children.size()>=2){
            walkAST(*node.children[0]);// condition expression
            TmpVar* result = getProperty(*node.children[0],"result");
            TmpVar* ifFalse = TmpVar::num(0);
            TmpVar* end = TmpVar::label("endOfIf");
            if(node.children.size()==3){ // has else
                TmpVar* elsePart = TmpVar::label("else");
                appendIR(NULL, JE, result, ifFalse, elsePart);
                walkAST(*node.children[1]); // then part
                appendIR(NULL, JMP, NULL, NULL, end);
                appendIR(elsePart, NOOP, NULL, NULL, NULL);
                walkAST(*node.children[2]); // else part
                appendIR(end, NOOP, NULL, NULL, NULL);
            }else{ // no else
                setProperty(node, "gotoWhenFalse", end);
                appendIR(NULL, JE, result, ifFalse, end);
                walkAST(*node.children[1]); // then part
                appendIR(end, NOOP, NULL, NULL, NULL);
            }
        }
        return false;
    }

    virtual bool switchStmt(ASTNode &node, bool postVisit) override {
        if(!postVisit){
            TmpVar* endOfSwitch = TmpVar::label("endOfSwitch");
            TmpVar* condResult;
            for (int i=0; i<node.children.size(); i++) { // fist loop to get all conditions.
                ASTNode& child = *node.children[i];
                if(i==0){ // expression
                    walkAST(child);
                    condResult = getProperty(child, "result");
                }else{
                    if(child.astType==CaseSegment){
                        TmpVar* caseValue;
                        if(child.contentType==Char){
                            caseValue = TmpVar::num(child.ch);
                        }else{
                            caseValue = TmpVar::num(child.value);
                        }
                        TmpVar* caseBegin = setProperty(child, "Case", TmpVar::label("Case"));
                        appendIR(NULL, JE, condResult, caseValue, caseBegin);
                    }else{ // default segment
                        TmpVar* caseBegin = setProperty(child, "Case", TmpVar::label("Default"));
                        appendIR(NULL, JMP, NULL, NULL, caseBegin);
                    }
                }
            }

            for (int i=1; i<node.children.size(); i++) { // second loop of cases' code.
                ASTNode& child = *node.children[i];
                appendIR(getProperty(child, "Case"), NOOP, NULL, NULL, NULL);
                walkAST(child);
                appendIR(NULL, JMP, NULL, NULL, endOfSwitch);
            }
            // end of switch
            appendIR(endOfSwitch, NOOP, NULL, NULL, NULL);
        }
        return false;
    }

    // need do nothing because we've done everything in switchStmt.
    virtual bool caseStmt(ASTNode &node, bool postVisit) override {
//        if(!postVisit){
//            TmpVar* caseBegin = TmpVar::label("Case");
//            setProperty(node, "Case", caseBegin);
//            if(node.contentType==Char){
//                setProperty(node, "Value", TmpVar::num(node.ch));
//            }else{
//                setProperty(node, "Value", TmpVar::num(node.value));
//            }
//            appendIR(caseBegin, NOOP, NULL, NULL, NULL);
//        }else{
//            TmpVar* endOfSwitch = getProperty(*(node.parent), "endOfSwitch");
//            appendIR(NULL, JMP, NULL, NULL, endOfSwitch);
//        }
        return true;
    }
// need do nothing because we've done everything in switchStmt.
    virtual bool defaultStmt(ASTNode &node, bool postVisit) override {
//        if(!postVisit){
//            TmpVar* caseBegin = TmpVar::label("Case");
//            setProperty(node, "Case", TmpVar::label("Case"));
//            appendIR(caseBegin, NOOP, NULL, NULL, NULL);
//        }else{
//            TmpVar* endOfSwitch = getProperty(*(node.parent), "endOfSwitch");
//            appendIR(NULL, JMP, NULL, NULL, endOfSwitch);
//        }
        return true;
    }

    virtual bool whileStmt(ASTNode &node, bool postVisit) override {
        if(!postVisit && node.children.size()>1){
            TmpVar* whileCondTest = TmpVar::label("whileBeginTest");
            appendIR(whileCondTest, NOOP, NULL, NULL, NULL);
            walkAST(*node.children[0]);
            TmpVar* result = getProperty(*node.children[0],"result");
            TmpVar* ifFalse = TmpVar::num(0);
            TmpVar* endOfWhile = TmpVar::label("endOfWhile");
            appendIR(NULL, JE, result, ifFalse, endOfWhile);
            for(int i=1; i<node.children.size(); i++){
                walkAST(*node.children[i]);
            }
            appendIR(NULL, JMP, NULL, NULL, whileCondTest);
            appendIR(endOfWhile, NOOP, NULL, NULL, NULL);
        }
        return false;
    }

    virtual bool arraySubscript(ASTNode &node, bool postVisit) override {
        if(postVisit){
            SymbolTable::VarRecord* arr = symTable.getVar(node.id);
            TmpVar* arrayVar = arr->ref;
            TmpVar* index = getProperty(*node.children[0], "result");
            if(node.parent->astType==AssignStmt &&
               node.parent->children[0] == &node){ // left value;
                setProperty(node, "result", index);
                setProperty(node, "arrName", arrayVar);
            }else{ // right value
                TmpVar* result = TmpVar::var();
                appendIR(NULL, DEFREF, arrayVar, index, result);
                setProperty(node, "result", result);
                if(node.needCheck){
                    int2CharConvertCheck(result);
                }
            }
        }
        return true;
    }

    virtual bool assignStmt(ASTNode &node, bool postVisit) override {
        if(postVisit){
            TmpVar* left = getProperty(*node.children[0], "result");
            TmpVar* right = getProperty(*node.children[1], "result");
            if(node.children[0]->astType==ArraySubscript){
                TmpVar* arrName = getProperty(*node.children[0], "arrName");
                appendIR(NULL, ASSIGNARR, arrName, left, right);
            }else{
                appendIR(NULL, ASSIGN, right, NULL, left);
            }
        }
        return true;
    }

    virtual bool returnStmt(ASTNode &node, bool postVisit) override {
        if(postVisit){
            if(node.children.size()>0){
                appendIR(NULL, RETURN, getProperty(*node.children[0], "result"), NULL, NULL);
            }else{
                appendIR(NULL, RETURN, NULL, NULL, NULL);
            }
        }
        return true;
    }

    virtual bool identifier(ASTNode &node) override {
        SymbolTable::VarRecord* v = symTable.getVar(node.id);
        setProperty(node, "result", v->ref);
        if(node.needCheck){
            int2CharConvertCheck(v->ref);
        }
        return true;
    }

    virtual bool sourceCode(ASTNode &node, bool postVisit) override {
        if(!postVisit){
            errorLabel = TmpVar::label("throwErrorExit");
            checkConvertLabel = TmpVar::label("checkConvert");
            setProperty(node, "programBegin", TmpVar::label("programBegin"));
        }else{
            TmpVar* program = getProperty(node, "programBegin");
            TmpVar* main = getProperty(node, "mainFun");
            appendIR(program, CALL, main, NULL, NULL);
//            generate check code of int to char.
            TmpVar* endOfProgram = TmpVar::label("endOfProgram");
            appendIR(NULL, JMP, NULL, NULL, endOfProgram);
            TmpVar* value = TmpVar::var();
            TmpVar* check57 = TmpVar::label("check57");
            TmpVar* check90 = TmpVar::label("check90");
            TmpVar* check122 = TmpVar::label("check122");
            TmpVar* error = TmpVar::label("error");
            TmpVar* end = TmpVar::label("endOfCheck");
            appendIR(checkConvertLabel, FUNBEGIN, NULL, NULL, NULL);
            appendIR(NULL, POPPARAM, NULL, NULL, value);
            appendIR(NULL, JE, value, TmpVar::num(42), end);//*
            appendIR(NULL, JE, value, TmpVar::num(43), end);//+
            appendIR(NULL, JE, value, TmpVar::num(45), end);//-
            appendIR(NULL, JGE, value, TmpVar::num(47), check57);//div/
            appendIR(NULL, JMP, NULL, NULL, error);
            appendIR(check57,JLE, value, TmpVar::num(57), end);// <=57?
            appendIR(NULL, JGE, value, TmpVar::num(65), check90);// >=65?
            appendIR(NULL, JMP, NULL, NULL, error);
            appendIR(check90,JLE, value, TmpVar::num(90), end); // <=90?
            appendIR(NULL, JGE, value, TmpVar::num(97), check122);//>=97?
            appendIR(NULL, JMP, NULL, NULL, error);
            appendIR(check122, JLE, value, TmpVar::num(122), end);//<=122?
            appendIR(error, ERROR, TmpVar::num(1), NULL, NULL);
            appendIR(end, RETURN, NULL, NULL, NULL);
            appendIR(NULL, FUNEND, NULL, NULL, NULL);
            appendIR(endOfProgram, NOOP, NULL, NULL, NULL);
        }
        return true;
    }

    virtual bool number(ASTNode &node) override {
        setProperty(node, "result", TmpVar::num(node.value));
        return false;
    }

    virtual bool charConst(ASTNode &node) override {
        setProperty(node, "result", TmpVar::num(node.ch));
        return false;
    }

    virtual bool stringConst(ASTNode &node) override {
        setProperty(node, "result", TmpVar::str(node.content));
        return false;
    }

};


#endif //C0COMPILER_ASTTOIRINTERPRETER_H
